# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Generates and prints out imports and constants for new TensorFlow python api."""
import argparse
import collections
import importlib
import os
import sys

from tensorflow.python.tools.api.generator import doc_srcs
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_export

# tf_export attribute-name records, indexed by API name (e.g.
# `API_ATTRS[api_name].names` / `.constants`); API_ATTRS is used for API
# versions other than 1, API_ATTRS_V1 for API version 1.
API_ATTRS = tf_export.API_ATTRS
API_ATTRS_V1 = tf_export.API_ATTRS_V1

# Default for whether generated __init__.py files load symbols lazily; used
# as the default argument value and when --loading=default.
_LAZY_LOADING = False
# API versions that can be generated (the valid values of --apiversion).
_API_VERSIONS = [1, 2]
# Name of the compat module for API version N, e.g. 'compat.v1'.
_COMPAT_MODULE_TEMPLATE = 'compat.v%d'
# Name of compat module K nested inside compat module N, e.g.
# 'compat.v2.compat.v1'.
_SUBCOMPAT_MODULE_TEMPLATE = 'compat.v%d.compat.v%d'
# Prefix shared by all compat module names; API v1 deprecation is not enabled
# for modules starting with it.
_COMPAT_MODULE_PREFIX = 'compat.v'
# Default value of the --packages flag.
_DEFAULT_PACKAGE = 'tensorflow.python'
# NOTE(review): not referenced anywhere in this file; confirm whether external
# code still uses it before removing.
_GENFILES_DIR_SUFFIX = 'genfiles/'
# Fully-qualified names that are skipped (never fetched or unwrapped) while
# scanning modules for exported symbols.
_SYMBOLS_TO_SKIP_EXPLICITLY = {
    # Overrides __getattr__, so that unwrapping tf_decorator
    # would have side effects.
    'tensorflow.python.platform.flags.FLAGS'
}
# Header of generated __init__.py files that have no template; the %s is
# filled with the module docstring.
_GENERATED_FILE_HEADER = """# This file is MACHINE GENERATED! Do not edit.
# Generated by: tensorflow/python/tools/api/generator/create_python_api.py script.
\"\"\"%s
\"\"\"

import sys as _sys

"""
# Appended after the imports of generated files that have no template.
_GENERATED_FILE_FOOTER = ''
# Footer wrapping a generated module in TFModuleWrapper; formatted with
# (module name, public-APIs variable name or 'None', deprecation flag,
# has_lite flag).
_DEPRECATION_FOOTER = """
from tensorflow.python.util import module_wrapper as _module_wrapper

if not isinstance(_sys.modules[__name__], _module_wrapper.TFModuleWrapper):
  _sys.modules[__name__] = _module_wrapper.TFModuleWrapper(
      _sys.modules[__name__], "%s", public_apis=%s, deprecation=%s,
      has_lite=%s)
"""
# Body of a lazily-loaded module; the %s is filled with newline-separated
# `'name': ('module', 'attr'),` entries of the _PUBLIC_APIS dict.
_LAZY_LOADING_MODULE_TEXT_TEMPLATE = """
# Inform pytype that this module is dynamically populated (b/111239204).
_HAS_DYNAMIC_ATTRIBUTES = True
_PUBLIC_APIS = {
%s
}
"""


class SymbolExposedTwiceError(Exception):
  """Signals that two distinct objects were exported under one API name.

  The module builder raises this when a fully-qualified API name that is
  already taken is registered again for a different symbol.
  """


def get_canonical_import(import_set):
  """Obtain one single import from a set of possible sources of a symbol.

  One symbol might come from multiple places as it is being imported and
  reexported. To keep the generated API stable, we always pick the same import
  for the same symbol: the one with the highest priority, with ties broken by
  alphabetical order of the import string.

  Args:
    import_set: (set) Imports providing the same symbol. This is a set of tuples
      in the form (import, priority). We want to pick an import with highest
      priority.

  Returns:
    The selected import string.

  Raises:
    ValueError: If `import_set` is empty.
  """
  # The key orders by descending priority, then ascending import string, so
  # the minimum is the preferred import. A single linear scan is enough.
  # Distinct (import, priority) tuples never share a key, so the result does
  # not depend on set iteration order.
  best = min(
      import_set,
      key=lambda imp_and_priority: (-imp_and_priority[1], imp_and_priority[0]),
      default=None)
  if best is None:
    raise ValueError('import_set must contain at least one import.')
  return best[0]


class _ModuleInitCodeBuilder(object):
  """Builds a map from module name to imports included in that module.

  Imports are registered with `add_import` (and `copy_imports`) and rendered
  into `__init__.py` text by `build`. With lazy loading enabled, each import
  is rendered as a `'name': ('module', 'attr'),` entry of a `_PUBLIC_APIS`
  dict; otherwise it is rendered as a Python import statement.
  """

  def __init__(self,
               output_package,
               api_version,
               lazy_loading=_LAZY_LOADING,
               use_relative_imports=False):
    """Initializes an empty builder.

    Args:
      output_package: (string) Root package of the generated API; submodule
        imports are made from it.
      api_version: (int) API version being generated. Version 1 modules get a
        deprecation wrapper footer.
      lazy_loading: (bool) Whether exported symbols are lazily loaded instead
        of statically imported.
      use_relative_imports: (bool) Whether submodules are imported with
        `from . import ...` in static mode.
    """
    self._output_package = output_package
    # Maps API module to full API name to set of tuples of the form
    # (import string, priority).
    # The same symbol can be imported from multiple locations. Higher
    # "priority" indicates that import location is preferred over others.
    self._module_imports = collections.defaultdict(
        lambda: collections.defaultdict(set))
    # Maps full API name to id() of the exported object, or -1 for exports
    # without an object (constants and submodules).
    self._dest_import_to_id = collections.defaultdict(int)
    # Names that start with underscore in the root module.
    self._underscore_names_in_root = set()
    self._api_version = api_version
    # Controls whether or not exported symbols are lazily loaded or statically
    # imported.
    self._lazy_loading = lazy_loading
    self._use_relative_imports = use_relative_imports

  def _check_already_imported(self, symbol_id, api_name):
    """Records `api_name`, rejecting a second, different symbol for it.

    Args:
      symbol_id: id() of the exported object, or -1 if there is none.
      api_name: (string) Fully-qualified destination API name.

    Raises:
      SymbolExposedTwiceError: If `api_name` is already registered for a
        different symbol.
    """
    if api_name not in self._dest_import_to_id:
      self._dest_import_to_id[api_name] = symbol_id
      return
    if symbol_id == -1:
      # Object-less imports never conflict. Keep any previously recorded real
      # id: replacing it with -1 would make a later re-export of that same
      # symbol look like a conflict.
      return
    if symbol_id != self._dest_import_to_id[api_name]:
      raise SymbolExposedTwiceError(
          f'Trying to export multiple symbols with same name: {api_name}')

  def add_import(self, symbol, source_module_name, source_name,
                 dest_module_name, dest_name):
    """Adds this import to module_imports.

    Args:
      symbol: TensorFlow Python symbol, or None for imports without an
        associated object (constants and submodules).
      source_module_name: (string) Module to import from.
      source_name: (string) Name of the symbol to import.
      dest_module_name: (string) Module name to add import to.
      dest_name: (string) Import the symbol using this name.

    Raises:
      SymbolExposedTwiceError: Raised when an import with the same
        dest_name has already been added to dest_module_name.
    """
    # modules_with_exports.py is only used during API generation and
    # won't be available when actually importing tensorflow.
    if source_module_name.endswith('python.modules_with_exports'):
      source_module_name = symbol.__module__
    import_str = self.format_import(source_module_name, source_name, dest_name)

    # Check if we are trying to expose two different symbols with same name.
    full_api_name = dest_name
    if dest_module_name:
      full_api_name = dest_module_name + '.' + full_api_name
    # Compare against None rather than truthiness: a falsy exported object
    # is still a real symbol and must take part in duplicate detection.
    has_symbol = symbol is not None
    symbol_id = id(symbol) if has_symbol else -1
    self._check_already_imported(symbol_id, full_api_name)

    if not dest_module_name and dest_name.startswith('_'):
      self._underscore_names_in_root.add(dest_name)

    # The same symbol can be available in multiple modules.
    # We store all possible ways of importing this symbol and later pick just
    # one.
    priority = 0
    if has_symbol:
      # Give higher priority to source module if it matches
      # symbol's original module.
      if hasattr(symbol, '__module__'):
        priority = int(source_module_name == symbol.__module__)
      # Give higher priority if symbol name matches its __name__.
      if hasattr(symbol, '__name__'):
        priority += int(source_name == symbol.__name__)
    self._module_imports[dest_module_name][full_api_name].add(
        (import_str, priority))

  def _import_submodules(self):
    """Add imports for all destination modules in self._module_imports."""
    # Import all required modules in their parent modules.
    # For e.g. if we import 'foo.bar.Value'. Then, we also
    # import 'bar' in 'foo'.
    # Snapshot the keys: add_import below adds new destination modules.
    imported_modules = set(self._module_imports.keys())
    for module in imported_modules:
      if not module:
        continue
      module_split = module.split('.')
      parent_module = ''  # we import submodules in their parent_module

      for submodule_index in range(len(module_split)):
        if submodule_index > 0:
          submodule = module_split[submodule_index - 1]
          parent_module += '.' + submodule if parent_module else submodule
        import_from = self._output_package
        if self._lazy_loading:
          # Lazy entries name the full dotted path of the submodule.
          import_from += '.' + '.'.join(module_split[:submodule_index + 1])
          self.add_import(
              symbol=None,
              source_module_name='',
              source_name=import_from,
              dest_module_name=parent_module,
              dest_name=module_split[submodule_index])
        else:
          if self._use_relative_imports:
            import_from = '.'
          elif submodule_index > 0:
            import_from += '.' + '.'.join(module_split[:submodule_index])
          self.add_import(
              symbol=None,
              source_module_name=import_from,
              source_name=module_split[submodule_index],
              dest_module_name=parent_module,
              dest_name=module_split[submodule_index])

  def build(self):
    """Renders the collected imports into `__init__.py` text.

    Returns:
      A tuple `(module_text_map, footer_text_map, root_module_footer)`:
        module_text_map: dict from destination module ('' for the root
          module, e.g. 'compat.v1' otherwise) to the import text for that
          module's __init__.py.
        footer_text_map: dict from destination module to its TFModuleWrapper
          footer; only populated for API v1 or when lazy loading.
        root_module_footer: text defining `__all__` for the root module in
          static mode, '' when lazy loading.
    """
    self._import_submodules()
    module_text_map = {}
    footer_text_map = {}
    for dest_module, dest_name_to_imports in self._module_imports.items():
      # Pick one canonical import per exported name, in sorted order.
      imports_list = sorted(
          get_canonical_import(imports)
          for imports in dest_name_to_imports.values())
      if self._lazy_loading:
        module_text_map[dest_module] = (
            _LAZY_LOADING_MODULE_TEXT_TEMPLATE % '\n'.join(imports_list))
      else:
        module_text_map[dest_module] = '\n'.join(imports_list)

    # Expose exported symbols with underscores in root module since we import
    # from it using * import. Don't need this for lazy_loading because the
    # underscore symbols are already included in __all__ when passed in and
    # handled by TFModuleWrapper.
    root_module_footer = ''
    if not self._lazy_loading:
      underscore_names_str = ', '.join(
          '\'%s\'' % name for name in sorted(self._underscore_names_in_root))

      root_module_footer = """
_names_with_underscore = [%s]
__all__ = [_s for _s in dir() if not _s.startswith('_')]
__all__.extend([_s for _s in _names_with_underscore])
""" % underscore_names_str

    # Add module wrapper if we need to print deprecation messages
    # or if we use lazy loading.
    if self._api_version == 1 or self._lazy_loading:
      public_apis_name = '_PUBLIC_APIS' if self._lazy_loading else 'None'
      # Workaround to make sure not load lite from lite/__init__.py
      root_has_lite = self._lazy_loading and 'lite' in self._module_imports
      for dest_module in self._module_imports:
        deprecation = 'False'
        # Add 1.* deprecations, except for compat modules.
        if (self._api_version == 1 and
            not dest_module.startswith(_COMPAT_MODULE_PREFIX)):
          deprecation = 'True'
        has_lite = 'True' if (not dest_module and root_has_lite) else 'False'
        footer_text_map[dest_module] = _DEPRECATION_FOOTER % (
            dest_module, public_apis_name, deprecation, has_lite)

    return module_text_map, footer_text_map, root_module_footer

  def format_import(self, source_module_name, source_name, dest_name):
    """Formats import statement.

    Args:
      source_module_name: (string) Source module to import from.
      source_name: (string) Source symbol name to import.
      dest_name: (string) Destination alias name.

    Returns:
      An import statement string, or a `_PUBLIC_APIS` dict entry when lazy
      loading.
    """
    if self._lazy_loading:
      return "  '%s': ('%s', '%s')," % (dest_name, source_module_name,
                                        source_name)
    else:
      if source_module_name:
        if source_name == dest_name:
          return 'from %s import %s' % (source_module_name, source_name)
        else:
          return 'from %s import %s as %s' % (source_module_name, source_name,
                                              dest_name)
      else:
        if source_name == dest_name:
          return 'import %s' % source_name
        else:
          return 'import %s as %s' % (source_name, dest_name)

  def get_destination_modules(self):
    """Returns a new set of all destination module names registered so far."""
    return set(self._module_imports.keys())

  def copy_imports(self, from_dest_module, to_dest_module):
    """Replaces the imports of `to_dest_module` with those of `from_dest_module`.

    The per-name import sets are copied as well, so later additions to either
    module do not leak into the other.
    """
    # Indexing the defaultdict creates an empty `from_dest_module` entry if
    # it does not exist yet (same as before).
    source_imports = self._module_imports[from_dest_module]
    self._module_imports[to_dest_module] = collections.defaultdict(
        set, {name: set(imports) for name, imports in source_imports.items()})


def add_nested_compat_imports(module_builder, compat_api_versions,
                              output_package):
  """Registers compat.vN.compat.vK modules with the module builder.

  To avoid circular imports, dedicated __init__.py files are produced for
  compat.vN.compat.vK and compat.vN.compat.vK.compat. Every other child module
  of those nested packages is imported from the corresponding module under
  compat.vK.

  Args:
    module_builder: `_ModuleInitCodeBuilder` instance.
    compat_api_versions: Supported compatibility versions.
    output_package: Base output python package where generated API will be
      added.
  """
  # Snapshot before copying so only pre-existing modules are scanned below.
  existing_modules = module_builder.get_destination_modules()

  # For every (N, K) pair, mirror compat.vK into compat.vN.compat.vK and
  # compat.vK.compat into compat.vN.compat.vK.compat.
  version_pairs = [(outer, inner)
                   for outer in compat_api_versions
                   for inner in compat_api_versions]
  for outer, inner in version_pairs:
    nested_module = _SUBCOMPAT_MODULE_TEMPLATE % (outer, inner)
    top_module = _COMPAT_MODULE_TEMPLATE % inner
    module_builder.copy_imports(top_module, nested_module)
    module_builder.copy_imports('%s.compat' % top_module,
                                '%s.compat' % nested_module)

  # Prefixes of modules under compatibility packages, e.g. "compat.v1.".
  compat_prefixes = tuple(
      _COMPAT_MODULE_TEMPLATE % v + '.' for v in compat_api_versions)

  # The copies above only carry function, class and constant imports; child
  # modules are wired up here.
  for module_name in existing_modules:
    if not module_name.startswith(compat_prefixes):
      continue
    parts = module_name.split('.')

    if len(parts) > 3 and parts[2] == 'compat':
      # compat.vK.compat.foo: import it into compat.vN.compat.vK.compat.
      src_module, src_name = '.'.join(parts[:3]), parts[3]
      assert src_name != 'v1' and src_name != 'v2', module_name
    else:
      # compat.vK.foo: import it into compat.vN.compat.vK.
      src_module, src_name = '.'.join(parts[:2]), parts[2]
      if src_name == 'compat':
        continue  # compat.vN.compat.vK.compat is handled separately

    qualified_source = '%s.%s' % (output_package, src_module)
    for version in compat_api_versions:
      module_builder.add_import(
          symbol=None,
          source_module_name=qualified_source,
          source_name=src_name,
          dest_module_name='compat.v%d.%s' % (version, src_module),
          dest_name=src_name)


def _get_name_and_module(full_name):
  """Split full_name into module and short name.

  Args:
    full_name: Full name of symbol that includes module.

  Returns:
    Full module name and short symbol name.
  """
  name_segments = full_name.split('.')
  return '.'.join(name_segments[:-1]), name_segments[-1]


def _join_modules(module1, module2):
  """Concatenate 2 module components.

  Args:
    module1: First module to join.
    module2: Second module to join.

  Returns:
    Given two modules aaa.bbb and ccc.ddd, returns a joined
    module aaa.bbb.ccc.ddd.
  """
  if not module1:
    return module2
  if not module2:
    return module1
  return '%s.%s' % (module1, module2)


def add_imports_for_symbol(module_code_builder,
                           symbol,
                           source_module_name,
                           source_name,
                           api_name,
                           api_version,
                           output_module_prefix=''):
  """Registers every API export of `symbol` with `module_code_builder`.

  Two cases are handled, and both may apply. If `source_name` is the API's
  constants attribute, `symbol` is iterated as (exports, name) pairs. If the
  symbol carries the API's names attribute in its `__dict__`, one import per
  listed export is added.

  Args:
    module_code_builder: `_ModuleInitCodeBuilder` instance.
    symbol: A symbol.
    source_module_name: Module that we can import the symbol from.
    source_name: Name we can import the symbol with.
    api_name: API name. Currently, must be `tensorflow`.
    api_version: API version.
    output_module_prefix: Prefix to prepend to destination module.
  """
  attrs_by_api = API_ATTRS_V1 if api_version == 1 else API_ATTRS
  api_attrs = attrs_by_api[api_name]
  names_attr = api_attrs.names
  constants_attr = api_attrs.constants

  def _destination(export):
    # Maps an exported dotted name to (destination module, destination name).
    dest_module, dest_name = _get_name_and_module(export)
    return _join_modules(output_module_prefix, dest_module), dest_name

  # Constants: `symbol` is the constants attribute itself.
  if source_name == constants_attr:
    for exports, constant_name in symbol:
      for export in exports:
        dest_module, dest_name = _destination(export)
        module_code_builder.add_import(None, source_module_name,
                                       constant_name, dest_module, dest_name)

  # Decorated symbols: their own __dict__ lists the exported names.
  is_exported = hasattr(symbol, '__dict__') and names_attr in symbol.__dict__
  if is_exported:
    for export in getattr(symbol, names_attr):
      dest_module, dest_name = _destination(export)
      module_code_builder.add_import(symbol, source_module_name, source_name,
                                     dest_module, dest_name)


def get_api_init_text(packages,
                      packages_to_ignore,
                      output_package,
                      api_name,
                      api_version,
                      compat_api_versions=None,
                      lazy_loading=_LAZY_LOADING,
                      use_relative_imports=False):
  """Get a map from destination module to __init__.py code for that module.

  Args:
    packages: Base python packages containing python with target tf_export
      decorators.
    packages_to_ignore: python packages to be ignored when checking for
      tf_export decorators. Empty entries are ignored.
    output_package: Base output python package where generated API will be
      added.
    api_name: API you want to generate Currently, only `tensorflow`.
    api_version: API version you want to generate (1 or 2).
    compat_api_versions: Additional API versions to generate under compat/
      directory.
    lazy_loading: Boolean flag. If True, a lazy loading `__init__.py` file is
      produced and if `False`, static imports are used.
    use_relative_imports: True if we should use relative imports when importing
      submodules.

  Returns:
    The tuple `(module_text_map, footer_text_map, root_module_footer)` from
    `_ModuleInitCodeBuilder.build`, where module_text_map maps a destination
    module (e.g. '' for the root, 'compat.v1') to the text that should be in
    its __init__.py file.
  """
  if compat_api_versions is None:
    compat_api_versions = []
  # Drop empty entries up front: '' is a substring of every module name, and
  # splitting an empty --packages_to_ignore flag yields [''].
  ignored_packages = [p for p in (packages_to_ignore or []) if p]
  module_code_builder = _ModuleInitCodeBuilder(output_package, api_version,
                                               lazy_loading,
                                               use_relative_imports)

  def in_packages(module_name):
    return any(package in module_name for package in packages)

  def is_ignored(module_name):
    return any(package in module_name for package in ignored_packages)

  # Traverse over everything imported so far. Specifically,
  # we want to traverse over TensorFlow Python modules. Iterate a copy because
  # getattr() below may import more modules.
  for module in list(sys.modules.values()):
    if not module:
      continue
    module_name = getattr(module, '__name__', None)
    # Only look at tensorflow modules.
    if module_name is None or not in_packages(module_name):
      continue
    if is_ignored(module_name):
      continue

    # Do not generate __init__.py files for contrib modules for now.
    if (('.contrib.' in module_name or module_name.endswith('.contrib'))
        and '.lite' not in module_name):
      continue

    for module_contents_name in dir(module):
      if (module_name + '.' +
          module_contents_name in _SYMBOLS_TO_SKIP_EXPLICITLY):
        continue
      attr = getattr(module, module_contents_name)
      _, attr = tf_decorator.unwrap(attr)

      add_imports_for_symbol(module_code_builder, attr, module_name,
                             module_contents_name, api_name, api_version)
      for compat_api_version in compat_api_versions:
        add_imports_for_symbol(module_code_builder, attr, module_name,
                               module_contents_name, api_name,
                               compat_api_version,
                               _COMPAT_MODULE_TEMPLATE % compat_api_version)

  if compat_api_versions:
    add_nested_compat_imports(module_code_builder, compat_api_versions,
                              output_package)
  return module_code_builder.build()


def get_module(dir_path, relative_to_dir):
  """Maps a directory to its dotted module name relative to `relative_to_dir`.

  Args:
    dir_path: Path to directory. `relative_to_dir` is expected to be a literal
      prefix of it.
    relative_to_dir: Get module relative to this directory.

  Returns:
    Dotted module name for the directory, or '' for `relative_to_dir` itself.
  """
  remainder = dir_path[len(relative_to_dir):]
  # Both the native separator and '/' turn into package dots; leading and
  # trailing dots left over from separators are then trimmed.
  separators_to_dots = str.maketrans({os.sep: '.', '/': '.'})
  return remainder.translate(separators_to_dots).strip('.')


def get_module_docstring(module_name, package, api_name):
  """Get docstring for the given module.

  This method looks for docstring in the following order:
  1. Checks if module has a docstring specified in doc_srcs.
  2. Checks if module has a docstring source module specified
     in doc_srcs. If it does, gets docstring from that module.
  3. Checks if module with module_name exists under base package.
     If it does, gets docstring from that module.
  4. Returns a default docstring.

  Args:
    module_name: module name relative to tensorflow (excluding 'tensorflow.'
      prefix) to get a docstring for.
    package: Base python package containing python with target tf_export
      decorators.
    api_name: API you want to generate Currently, only `tensorflow`.

  Returns:
    Docstring to describe the module.
  """
  # Get the same module doc strings for any version. That is, for module
  # 'compat.v1.foo' we can get docstring from module 'foo'. Prefixes are
  # stripped repeatedly so nested compat modules such as
  # 'compat.v2.compat.v1.foo' resolve to 'foo' regardless of the order of
  # _API_VERSIONS. A prefix only matches on a module boundary.
  stripped_prefix = True
  while stripped_prefix:
    stripped_prefix = False
    for version in _API_VERSIONS:
      compat_prefix = _COMPAT_MODULE_TEMPLATE % version
      if (module_name == compat_prefix or
          module_name.startswith(compat_prefix + '.')):
        module_name = module_name[len(compat_prefix):].strip('.')
        stripped_prefix = True

  # Module under base package to get a docstring from.
  docstring_module_name = module_name

  doc_sources = doc_srcs.get_doc_sources(api_name)

  if module_name in doc_sources:
    docsrc = doc_sources[module_name]
    if docsrc.docstring:
      return docsrc.docstring
    if docsrc.docstring_module_name:
      docstring_module_name = docsrc.docstring_module_name

  # tf_keras module names are looked up without the package prefix.
  if package != 'tf_keras':
    docstring_module_name = package + '.' + docstring_module_name
  if (docstring_module_name in sys.modules and
      sys.modules[docstring_module_name].__doc__):
    return sys.modules[docstring_module_name].__doc__

  return 'Public API for tf.%s namespace.' % module_name


def create_primary_api_files(output_files,
                             packages,
                             packages_to_ignore,
                             root_init_template,
                             output_dir,
                             output_package,
                             api_name,
                             api_version,
                             compat_api_versions,
                             compat_init_templates,
                             lazy_loading=_LAZY_LOADING,
                             use_relative_imports=False):
  """Creates __init__.py files for the Python API.

  Args:
    output_files: List of __init__.py file paths to create.
    packages: Base python packages containing python with target tf_export
      decorators.
    packages_to_ignore: python packages to be ignored when checking for
      tf_export decorators.
    root_init_template: Template for top-level __init__.py file. "# API IMPORTS
      PLACEHOLDER" comment in the template file will be replaced with imports.
    output_dir: output API root directory.
    output_package: Base output package where generated API will be added.
    api_name: API you want to generate Currently, only `tensorflow`.
    api_version: API version to generate (`v1` or `v2`).
    compat_api_versions: Additional API versions to generate in compat/
      subdirectory.
    compat_init_templates: List of templates for top level compat init files in
      the same order as compat_api_versions.
    lazy_loading: Boolean flag. If True, a lazy loading `__init__.py` file is
      produced and if `False`, static imports are used.
    use_relative_imports: True if we should use relative imports when import
      submodules.

  Raises:
    ValueError: if output_files list is missing a required file. All other
      files are still written before the error is raised.
  """
  module_name_to_file_path = {}
  for output_file in output_files:
    module_name = get_module(os.path.dirname(output_file), output_dir)
    module_name_to_file_path[module_name] = os.path.normpath(output_file)

  # Create file for each expected output in genrule, so every declared output
  # exists even if no API text is written to it below.
  for file_path in module_name_to_file_path.values():
    parent_dir = os.path.dirname(file_path)
    if parent_dir:
      os.makedirs(parent_dir, exist_ok=True)
    open(file_path, 'a').close()

  (
      module_text_map,
      deprecation_footer_map,
      root_module_footer,
  ) = get_api_init_text(packages, packages_to_ignore, output_package, api_name,
                        api_version, compat_api_versions, lazy_loading,
                        use_relative_imports)

  # Add imports to output files.
  missing_output_files = []
  # Root modules are "" and "compat.v*".
  root_module = ''
  # NOTE(review): zip() silently ignores versions that have no matching
  # template; confirm callers always pass equal-length lists.
  compat_module_to_template = {
      _COMPAT_MODULE_TEMPLATE % v: t
      for v, t in zip(compat_api_versions, compat_init_templates)
  }
  for v in compat_api_versions:
    compat_module_to_template.update({
        _SUBCOMPAT_MODULE_TEMPLATE % (v, vs): t
        for vs, t in zip(compat_api_versions, compat_init_templates)
    })

  for module, text in module_text_map.items():
    # Make sure genrule output file list is in sync with API exports.
    if module not in module_name_to_file_path:
      module_file_path = '"%s/__init__.py"' % (module.replace('.', '/'))
      missing_output_files.append(module_file_path)
      continue

    # Templates and generated files are Python source, which is UTF-8 by
    # default (PEP 3120); avoid depending on the locale encoding.
    contents = ''
    if module == root_module and root_init_template:
      # Read base init file for root module
      with open(root_init_template, 'r',
                encoding='utf-8') as root_init_template_file:
        contents = root_init_template_file.read()
        contents = contents.replace('# API IMPORTS PLACEHOLDER', text)
        contents = contents.replace('# __all__ PLACEHOLDER', root_module_footer)
    elif module in compat_module_to_template:
      # Read base init file for compat module
      with open(compat_module_to_template[module], 'r',
                encoding='utf-8') as init_template_file:
        contents = init_template_file.read()
        contents = contents.replace('# API IMPORTS PLACEHOLDER', text)
    else:
      contents = (
          _GENERATED_FILE_HEADER %
          get_module_docstring(module, packages[0], api_name) + text +
          _GENERATED_FILE_FOOTER)
    if module in deprecation_footer_map:
      if '# WRAPPER_PLACEHOLDER' in contents:
        contents = contents.replace('# WRAPPER_PLACEHOLDER',
                                    deprecation_footer_map[module])
      else:
        contents += deprecation_footer_map[module]
    with open(module_name_to_file_path[module], 'w', encoding='utf-8') as fp:
      fp.write(contents)

  if missing_output_files:
    missing_files = ',\n'.join(sorted(missing_output_files))
    raise ValueError(
        f'Missing outputs for genrule:\n{missing_files}. Be sure to add these '
        'targets to tensorflow/python/tools/api/generator/api_init_files_v1.bzl'
        ' and tensorflow/python/tools/api/generator/api_init_files.bzl '
        '(tensorflow repo), or tf_keras/api/api_init_files.bzl (tf_keras repo)')


def create_proxy_api_files(output_files,
                           proxy_module_root,
                           output_dir):
  """Creates __init__.py files in proxy format for the Python API.

  Args:
    output_files: List of __init__.py file paths to create.
    proxy_module_root: Module root for proxy-import format. If specified, proxy
      files with content like `from proxy_module_root.proxy_module import *`
      will be created to enable import resolution under TensorFlow.
    output_dir: output API root directory.
  """
  for file_path in output_files:
    module = get_module(os.path.dirname(file_path), output_dir)
    parent_dir = os.path.dirname(file_path)
    if parent_dir:
      os.makedirs(parent_dir, exist_ok=True)
    # The API root maps to `proxy_module_root` itself. Joining an empty module
    # name would produce the invalid statement `from root. import *`.
    if module:
      target_module = f'{proxy_module_root}.{module}'
    else:
      target_module = proxy_module_root
    contents = f'from {target_module} import *'
    with open(file_path, 'w') as fp:
      fp.write(contents)


def main():
  """Parses command-line flags and writes the generated API __init__.py files.

  Raises:
    ValueError: If the --loading flag has an unexpected value.
  """

  def _parse_bool(value):
    """Parses a boolean flag value.

    argparse's `type=bool` treats every non-empty string (including 'False')
    as True, so the value is parsed explicitly.
    """
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
      return True
    if normalized in ('false', 'f', 'no', 'n', '0', ''):
      return False
    raise argparse.ArgumentTypeError(
        f'Expected a boolean value, got: {value!r}')

  parser = argparse.ArgumentParser()
  parser.add_argument(
      'outputs',
      metavar='O',
      type=str,
      nargs='+',
      help='If a single file is passed in, then we assume it contains a '
      'semicolon-separated list of Python files that we expect this script to '
      'output. If multiple files are passed in, then we assume output files '
      'are listed directly as arguments.')
  parser.add_argument(
      '--packages',
      default=_DEFAULT_PACKAGE,
      type=str,
      help='Base packages that import modules containing the target tf_export '
      'decorators.')
  parser.add_argument(
      '--packages_to_ignore',
      default='',
      type=str,
      help='Packages to exclude from the api generation. This is used to hide '
      'certain packages from this script when multiple copy of code exists, '
      'eg tf_keras. It is useful to avoid the SymbolExposedTwiceError.'
      )
  parser.add_argument(
      '--root_init_template',
      default='',
      type=str,
      help='Template for top level __init__.py file. '
      '"# API IMPORTS PLACEHOLDER" comment will be replaced with imports.')
  parser.add_argument(
      '--apidir',
      type=str,
      required=True,
      help='Directory where generated output files are placed. '
      'gendir should be a prefix of apidir. Also, apidir '
      'should be a prefix of every directory in outputs.')
  parser.add_argument(
      '--apiname',
      required=True,
      type=str,
      choices=API_ATTRS.keys(),
      help='The API you want to generate.')
  parser.add_argument(
      '--apiversion',
      default=2,
      type=int,
      choices=_API_VERSIONS,
      help='The API version you want to generate.')
  # NOTE(review): the help text says 0 disables compat generation, but a 0 is
  # passed through unchanged to the generator; confirm intended behavior.
  parser.add_argument(
      '--compat_apiversions',
      default=[],
      type=int,
      action='append',
      help='Additional versions to generate in compat/ subdirectory. '
      'If set to 0, then no additional version would be generated.')
  parser.add_argument(
      '--compat_init_templates',
      default=[],
      type=str,
      action='append',
      help='Templates for top-level __init__ files under compat modules. '
      'The list of init file templates must be in the same order as '
      'list of versions passed with compat_apiversions.')
  parser.add_argument(
      '--output_package',
      default='tensorflow',
      type=str,
      help='Root output package.')
  parser.add_argument(
      '--loading',
      default='default',
      type=str,
      choices=['lazy', 'static', 'default'],
      help='Controls how the generated __init__.py file loads the exported '
      'symbols. \'lazy\' means the symbols are loaded when first used. '
      '\'static\' means all exported symbols are loaded in the '
      '__init__.py file. \'default\' uses the value of the '
      '_LAZY_LOADING constant in create_python_api.py.')
  parser.add_argument(
      '--use_relative_imports',
      default=False,
      type=_parse_bool,
      help='Whether to import submodules using relative imports or absolute '
      'imports')
  parser.add_argument(
      '--proxy_module_root',
      default=None,
      type=str,
      help='Module root for proxy-import format. If specified, proxy files with '
      'content like `from proxy_module_root.proxy_module import *` will be '
      'created to enable import resolution under TensorFlow.')
  args = parser.parse_args()

  if len(args.outputs) == 1:
    # If we only get a single argument, then it must be a file containing
    # list of outputs. Skip empty entries, e.g. from a trailing ';'.
    with open(args.outputs[0]) as output_list_file:
      entries = (entry.strip()
                 for entry in output_list_file.read().split(';'))
      outputs = [entry for entry in entries if entry]
  else:
    outputs = args.outputs

  # Populate `sys.modules` with modules containing tf_export(). Empty entries
  # (e.g. from a trailing ',') are skipped; import_module('') would fail.
  packages = [p for p in args.packages.split(',') if p]
  for package in packages:
    importlib.import_module(package)
  packages_to_ignore = [p for p in args.packages_to_ignore.split(',') if p]

  # Determine if the modules shall be loaded lazily or statically.
  if args.loading == 'default':
    lazy_loading = _LAZY_LOADING
  elif args.loading == 'lazy':
    lazy_loading = True
  elif args.loading == 'static':
    lazy_loading = False
  else:
    # This should never happen (tm).
    raise ValueError(f'Invalid value for --loading flag: {args.loading}. Must '
                     'be one of lazy, static, default.')
  if args.proxy_module_root is None:
    create_primary_api_files(outputs, packages, packages_to_ignore,
                             args.root_init_template, args.apidir,
                             args.output_package, args.apiname, args.apiversion,
                             args.compat_apiversions,
                             args.compat_init_templates, lazy_loading,
                             args.use_relative_imports)
  else:
    create_proxy_api_files(outputs, args.proxy_module_root, args.apidir)


# Run the generator only when executed as a script, not when imported.
if __name__ == '__main__':
  main()
